package com.zbxx.leetcode.practice.tree;


/**
 * LeetCode 1803: count the pairs whose XOR value lies within an inclusive range.
 *
 * <p>Values are inserted one at a time into a binary trie keyed on their bits, most significant
 * bit first. Before {@code nums[i]} is inserted, the trie holds exactly {@code nums[0..i-1]}, so
 * querying it with {@code nums[i]} counts every pair {@code (j, i)} with {@code j < i} once.
 *
 * @author wanrj
 * @date 2023/1/5
 */
public class TrieCountPairs {

    public static void main(String[] args) {
        int[] nums = {9, 8, 4, 2, 1};
        int low = 5, high = 14;
        System.out.println(countPairs(nums, low, high));
    }

    /**
     * Highest bit index used by the legacy overloads {@code appendNode(TrieNode, int)} and
     * {@code search(TrieNode, int, int)}: bits 16..0, i.e. values below 2^17. {@link #countPairs}
     * does not rely on it; it derives the bit width from its input instead.
     */
    final static int level = 16;

    /**
     * Counts index pairs {@code (i, j)} with {@code i < j} and
     * {@code low <= (nums[i] ^ nums[j]) <= high}.
     *
     * @param nums values to pair up; must be non-negative. {@code null} or fewer than two
     *             elements yields 0
     * @param low  inclusive lower bound on the XOR value
     * @param high inclusive upper bound on the XOR value
     * @return the number of qualifying pairs; 0 when {@code low > high} or {@code high < 0}
     * @throws IllegalArgumentException if {@code nums} contains a negative value
     */
    public static int countPairs(int[] nums, int low, int high) {
        if (nums == null || nums.length <= 1 || high < 0 || low > high) {
            return 0;
        }
        // The XOR of two non-negative ints is non-negative, so a negative lower bound acts as 0.
        int lowBound = Math.max(low, 0);
        int topBit = topBit(nums, high);
        TrieNode root = new TrieNode(0);
        int result = 0;
        for (int num : nums) {
            // #pairs with XOR <= high minus #pairs with XOR < low == #pairs with XOR in [low, high].
            // The inclusive query on `high` avoids `high + 1`, which overflows at Integer.MAX_VALUE.
            int atMostHigh = count(root, num, high, true, topBit);
            if (atMostHigh != 0) {
                result += atMostHigh - count(root, num, lowBound, false, topBit);
            }
            // Insert only after querying so a value is never paired with itself.
            appendNode(root, num, topBit);
        }
        return result;
    }

    /**
     * Returns the highest bit index needed to represent every value of {@code nums} and
     * {@code bound}. The XOR of values below 2^k is itself below 2^k, so this width is enough.
     */
    private static int topBit(int[] nums, int bound) {
        int bits = bound;
        for (int num : nums) {
            if (num < 0) {
                throw new IllegalArgumentException("nums must be non-negative, got " + num);
            }
            bits |= num;
        }
        // Keep at least one level so every inserted value ends at a node below the root
        // (the root's own count is never incremented).
        return Math.max(31 - Integer.numberOfLeadingZeros(bits), 0);
    }

    /** Inserts {@code val} into the trie using bits {@code level}..0. */
    public static void appendNode(TrieNode root, int val) {
        appendNode(root, val, level);
    }

    /** Inserts {@code val} along the path of its bits {@code topBit}..0, counting each visit. */
    private static void appendNode(TrieNode root, int val, int topBit) {
        TrieNode node = root;
        for (int bit = topBit; bit >= 0; bit--) {
            node = node.childVal((val >> bit) & 1);
        }
    }

    /**
     * Counts values {@code x} stored in the trie with {@code (x ^ val) < search} (strictly less),
     * comparing bits {@code level}..0.
     */
    public static int search(TrieNode root, int val, int search) {
        return count(root, val, search, false, level);
    }

    /**
     * Counts values {@code x} stored in the trie with {@code (x ^ val) < bound}, or
     * {@code <= bound} when {@code inclusive} is set, comparing bits {@code topBit}..0.
     */
    private static int count(TrieNode root, int val, int bound, boolean inclusive, int topBit) {
        TrieNode node = root;
        int total = 0;
        for (int bit = topBit; bit >= 0 && node != null; bit--) {
            int valBit = (val >> bit) & 1;
            // `same` stores the same bit as val (XOR bit 0); `diff` stores the other (XOR bit 1).
            TrieNode same = valBit == 1 ? node.right : node.left;
            TrieNode diff = valBit == 1 ? node.left : node.right;
            if (((bound >> bit) & 1) == 1) {
                // Bound has a 1 here: everything under `same` matches the bound on all higher bits
                // and has a 0 here, so it is strictly below the bound. Keep following the XOR=1 path.
                total = add(same, total);
                node = diff;
            } else {
                // Bound has a 0 here: only an XOR bit of 0 can stay within the bound.
                node = same;
            }
        }
        // Whatever remains on the path has an XOR exactly equal to the bound.
        if (inclusive) {
            total = add(node, total);
        }
        return total;
    }

    /** Returns {@code add} plus the count stored in {@code node}, treating {@code null} as 0. */
    public static int add(TrieNode node, int add) {
        return node == null ? add : add + node.count;
    }


    /** Binary trie node: {@code left} is the 0-bit child, {@code right} the 1-bit child. */
    private final static class TrieNode {

        /** Number of inserted values whose path passes through this node. */
        public int count;

        public TrieNode left;

        public TrieNode right;


        public TrieNode(int count) {
            this.count = count;
        }

        /**
         * Returns the child for bit {@code val} (0 selects {@code left}, anything else
         * {@code right}). A missing child is created with count 1; an existing child has its
         * count incremented.
         */
        public TrieNode childVal(int val) {
            if (val == 0) {
                if (this.left == null) {
                    this.left = new TrieNode(1);
                } else {
                    left.count++;
                }
                return left;
            }
            if (this.right == null) {
                this.right = new TrieNode(1);
            } else {
                right.count++;
            }
            return right;
        }

    }

}
